import matplotlib.pyplot as plt
from utils import test_classwise, train_samplewise_clapp
from data import load_SHD
from model import CLAPP_RSNN
import numpy as np
import torch
import seaborn as sns
from scipy.signal import savgol_filter
# One distinct colour per SHD class (20 spoken digits, EN + DE).
color_list = sns.color_palette('hls', 20)
device = 'cpu'
epochs = 1
# SHD spike data has 700 input channels (cochlea model).
n_inputs = 700 # 28*28 #34 * 34 * 2
# Five hidden layers of 512 neurons each.
n_hidden = 5 * [512]
n_outputs = 20
batch_size = 64
# Checkpoint of the 5-layer, non-recurrent CLAPP-trained model.
folder = 'models/'
model_name = folder + 'shd_5layer_norec.pt'
# Spiking Heidelberg Digits
#train_loader, test_loader = load_PMNIST(n_time_bins, scale=0.9, patches=True) #load_NMNIST(n_time_bins, batch_size=batch_size)
# Number of time bins per sample (used again below for the raster plots).
n_time_bins = 100
train_loader, test_loader = load_SHD(batch_size=batch_size) #load_NMNIST(n_time_bins, batch_size=batch_size)
# Plot Example: show three random training samples as (channel, time) images.
# NOTE(review): next_item(-1, contrastive=True) presumably draws a random
# sample; -1 seems to mean "no previous target" — confirm against data.py.
for i in range(3):
    frames, target = train_loader.next_item(-1, contrastive=True)
    plt.figure()
    # frames is (time, 1, channels); drop the batch dim and transpose so
    # channels run vertically and time horizontally.
    plt.imshow(frames.squeeze(1).T)
    plt.colorbar()
    print(frames.shape, target)
/home/lars/ownCloud/ETH/Master/Project_2/SNN_CLAPP/data.py:17: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). self.y = torch.tensor(y)
torch.Size([100, 1, 700]) tensor([4.]) torch.Size([100, 1, 700]) tensor([10.]) torch.Size([100, 1, 700]) tensor([6.])
# Instantiate the CLAPP spiking RNN and load the pretrained checkpoint.
SNN = CLAPP_RSNN(n_inputs, n_hidden, n_outputs, beta=0.95, out_proj=False, device=device, recurrent=False, cat=False).to(device)
SNN.load_state_dict(torch.load(model_name, map_location='cpu'))
# Skip the first `from_epoch` epochs of the logged CLAPP loss; the log holds
# one row per batch, so one epoch spans len(train_loader)/batch_size rows.
from_epoch = 5
train_clapp_loss = torch.load(model_name[:-3]+'_clapp_loss.pt', map_location='cpu')[int(from_epoch*len(train_loader)/batch_size):]
print(train_clapp_loss.shape)
# One smoothed curve per layer; x-axis converts batch index back to epochs.
for i in range(train_clapp_loss.shape[-1]):
    plt.plot(from_epoch+(batch_size*np.arange(train_clapp_loss.shape[0])/len(train_loader)), savgol_filter(train_clapp_loss[:,i], 99, 1))
plt.legend([f'layer {i+1}' for i in range(len(SNN.clapp))])
# plt.ylim([-3.,-2])
plt.xlabel('Epoch')
plt.ylabel('Clapp Loss')
# train_shd_supervised_clapp(SNN, train_loader, 1, 'cpu')
torch.Size([101313, 5])
Text(0, 0.5, 'Clapp Loss')
# Run the pretrained network over the test set, collecting per-batch hidden
# activations, targets and per-layer CLAPP losses.
clapp_activation, target_list, clapp_losses = test_classwise(SNN, test_loader, device, batch_size=batch_size, temporal=True)
# Stack once instead of three times; the original also printed the identical
# mean twice, so the duplicate print was removed.
stacked_losses = torch.stack(clapp_losses)
print(f'Mean CLAPP Loss: {stacked_losses.mean(axis=0)}')
# Plot the first 100 per-batch loss rows (one curve per layer).
plt.plot(stacked_losses.detach().cpu()[:100,:])
Mean CLAPP Loss: tensor([-113.9548, -159.1156, -196.8741, -210.1194, -227.6216]) tensor([-113.9548, -159.1156, -196.8741, -210.1194, -227.6216])
# Back-project each layer onto the input space by chaining the forward
# weight matrices (only the feed-forward part of each weight is used).
layers = [SNN.clapp[0].fc.weight[:, :n_inputs]]
for idx in range(1, len(SNN.clapp)):
    forward_part = SNN.clapp[idx].fc.weight[:, :n_hidden[idx - 1]]
    layers.append(forward_part @ layers[-1])
# Raw weight matrix of every CLAPP layer.
for clapp_layer in SNN.clapp:
    plt.figure()
    plt.imshow(clapp_layer.fc.weight.detach(), vmax=0.05, vmin=-0.05)
    plt.colorbar()
    # plt.figure()
    # plt.imshow(clapp_layer.pred.weight.detach(), vmax=0.5, vmin=-0.5)
    # plt.colorbar()
# Effective input-space receptive field of every layer.
for composed in layers:
    plt.figure()
    plt.imshow(composed.detach())
    plt.colorbar()
print(len(clapp_activation))
# Flatten per-batch activations into (n_layers, n_samples, 512) and the
# targets into a flat (n_samples,) vector.
hidden_activities_transformed = torch.stack(clapp_activation).swapaxes(0,1).reshape(len(SNN.clapp), -1, 512)
target_transformed = torch.stack(target_list).flatten()
print(hidden_activities_transformed.shape, target_transformed.shape)
from sklearn.decomposition import PCA
from umap import UMAP
from sklearn.manifold import TSNE
# Choose the dimensionality-reduction class here (PCA / UMAP / TSNE).
reduction = TSNE
colors = [color_list[i.int()] for i in target_transformed]
for i, hat in enumerate(hidden_activities_transformed):
    reduct = reduction(n_components=2)
    # hat_diff = hat[1:] - hat[:-1]
    print(hat.shape)
    hat_transform = reduct.fit_transform(hat.detach().cpu().numpy())
    print(hat_transform.shape)
    print(f'Total spikes in layer {i}: {hat.sum()}')
    plt.figure()
    plt.imshow(hat, vmax=15)
    plt.colorbar()
    plt.figure(figsize=(8,8))
    plt.title(f'Hidden Activations {i}')
    # Fixed: removed the dead `col = colors` assignment and renamed the inner
    # loop variable (previously `i`, shadowing the layer index) to `cls`.
    for cls in range(train_loader.num_classes):
        col_indeces = np.argwhere(target_transformed.squeeze() == cls).squeeze()
        hat_col = hat_transform[col_indeces, :]
        plt.scatter(hat_col[:,0], hat_col[:,1], s=6, color=color_list[cls], label=cls, alpha=1)
    plt.legend()
36 torch.Size([5, 2304, 512]) torch.Size([2304]) torch.Size([2304, 512]) (2304, 2) Total spikes in layer 0: 4584499.0 torch.Size([2304, 512]) (2304, 2) Total spikes in layer 1: 4654315.0 torch.Size([2304, 512]) (2304, 2) Total spikes in layer 2: 5770904.0 torch.Size([2304, 512]) (2304, 2) Total spikes in layer 3: 5382842.0 torch.Size([2304, 512]) (2304, 2) Total spikes in layer 4: 6065976.0
from model import CLAPP_out
from tqdm.notebook import tqdm
def train_out_proj(epochs, batch):
    """Train linear readout projections from the raw inputs and from every
    CLAPP layer, while the pretrained SNN itself stays frozen.

    Returns a tuple of (list of readouts [input-readout, layer-readouts...],
    accuracy history as an ndarray, stacked cross-entropy losses).

    NOTE(review): everything runs under torch.no_grad() and no backward()
    is ever called, yet the SGD optimizers are stepped — this presumably
    relies on CLAPP_out writing .grad via its local learning rule; confirm
    against model.py.
    """
    # train output projections from all layers (and no layer)
    losses_out = []
    optimizers = []
    out_projs = []
    print_interval = 10*batch
    # SNN.out_proj.out_proj.reset_parameters()
    # Readout directly from the 700 input channels (baseline without the SNN).
    out_proj_0 = CLAPP_out(700, 20, beta=0.95)
    optim_0 = torch.optim.SGD(out_proj_0.parameters(), lr=1e-4)
    # One readout + optimizer per hidden layer (512 units each).
    for lay in range(len(SNN.clapp)):
        out_projs.append(CLAPP_out(512, 20, beta=0.95))
        optimizers.append(torch.optim.SGD(out_projs[-1].parameters(), lr=1e-4))
        optimizers[-1].zero_grad()
    SNN.eval()
    # Initial dummy targets so next_item can draw a contrastive batch.
    target = batch*[0]
    acc = []
    # Per-readout correct-prediction counters (index 0 = input readout).
    correct = (len(SNN.clapp) + 1)*[0]
    with torch.no_grad():
        pbar = tqdm(total=len(train_loader)*epochs)
        # Loop until `epochs` worth of samples have been consumed.
        while len(losses_out)*batch < len(train_loader)*epochs:
            data, target = train_loader.next_item(target, contrastive=True)
            SNN.reset(0)
            # One list of per-timestep logits per readout.
            logit_lists = [[] for lay in range(len(SNN.clapp)+1)]
            data = data.squeeze()
            # Unroll the SNN over time; feed each layer's spikes to its readout.
            for step in range(data.shape[0]):
                data_step = data[step].float().to(device)
                target = target.to(device)
                logits, mem_his, clapp_loss = SNN(data_step, target, 0)
                logts, _ = out_proj_0(data_step, target)
                logit_lists[0].append(logts)
                for lay in range(len(SNN.clapp)):
                    logts, _ = out_projs[lay](logits[lay], target)
                    logit_lists[lay+1].append(logts)
            # Time-summed logits give one prediction per readout.
            preds = [torch.stack(logit_lists[lay]).sum(axis=0) for lay in range(len(SNN.clapp)+1)]
            # if pred.max() < 1: print(pred.max())
            # dL[lay] is a boolean mask of correctly classified samples.
            dL = [preds[lay].argmax(axis=-1) == target for lay in range(len(SNN.clapp)+1)]
            # reset(1 - correct): presumably gates the local update so only
            # misclassified samples contribute — confirm in CLAPP_out.reset.
            out_proj_0.reset(1-dL[0].float())
            for i, out_proj in enumerate(out_projs):
                out_proj.reset(1-dL[i+1].float())
            correct = [correct[lay] + dL[lay].sum() for lay in range(len(SNN.clapp)+1)]
            # Track the cross-entropy of each readout for monitoring only.
            losses_out.append(torch.tensor([torch.nn.functional.cross_entropy(preds[lay], target.squeeze().long()) for lay in range(len(SNN.clapp)+1)], requires_grad=False))
            optim_0.step()
            optim_0.zero_grad()
            for opt in optimizers:
                opt.step()
                opt.zero_grad()
            # Periodic progress report; counters are reset each interval.
            if len(losses_out)*batch % print_interval == 0:
                pbar.write(f'Cross Entropy Loss: {(torch.stack(losses_out)[-400//batch:].sum(dim=0)/(400//batch)).numpy()}\n' +
                           f'Correct: {100*np.array(correct)/print_interval}%')
                acc.append(np.array(correct)/print_interval)
                correct = (len(SNN.clapp) + 1)*[0]
            pbar.update(batch)
    return [out_proj_0, *out_projs], np.asarray(acc), torch.stack(losses_out)
# Train all readout projections: 10 epochs, batch size 80.
with torch.no_grad():
    out_projs, acc, losses_out = train_out_proj(10, 80)
0%| | 0/81560 [00:00<?, ?it/s]
Cross Entropy Loss: [3.0592885 3.304688 3.2149575 3.2233822 2.3872805 3.4384217] Correct: [ 2.875 11.375 18. 15.75 27. 15. ]% Cross Entropy Loss: [3.0628736 3.1732297 2.8184757 2.1746995 1.9051759 2.5545602] Correct: [ 4.5 17.875 25.875 41.5 55.75 36.375]% Cross Entropy Loss: [3.1697598 2.983493 2.0265057 1.8123022 1.6135648 2.02384 ] Correct: [ 5.625 24.25 40.75 52.75 58.75 45.75 ]% Cross Entropy Loss: [3.1936996 2.6097703 2.0575373 1.5988777 1.4927404 1.4794581] Correct: [ 4.875 26. 45.25 59.125 60.5 61.5 ]% Cross Entropy Loss: [3.0790238 2.4289439 1.6033024 1.3053102 1.1618885 1.264855 ] Correct: [ 7.375 30.25 47.125 60.25 65.125 69. ]% Cross Entropy Loss: [3.1002088 2.2225099 1.4810988 1.0873731 0.81291664 1.113727 ] Correct: [ 8.625 33.625 51.125 62.75 72.875 72.625]% Cross Entropy Loss: [3.165258 2.046066 1.1968672 0.7255853 0.6862627 1.0534015] Correct: [ 7.875 39. 59.875 71.5 76.625 73.875]% Cross Entropy Loss: [3.1344364 2.179806 1.4766083 0.94342005 0.7400807 1.0225474 ] Correct: [ 7.5 37.875 53.75 71.625 78.5 73. ]% Cross Entropy Loss: [3.1401048 2.1995368 1.4009577 0.827705 0.5798048 0.8583808] Correct: [ 8. 36.5 54.75 70.875 79.25 73.875]% Cross Entropy Loss: [3.3764114 2.0199063 1.2509757 0.75755894 0.60538054 0.868153 ] Correct: [ 9.25 38. 56.75 72.5 79.25 76. ]% Cross Entropy Loss: [3.0230844 2.1256597 1.0526004 0.68123925 0.61539155 1.0028805 ] Correct: [11.625 40.625 63.125 77.25 80.875 75.5 ]% Cross Entropy Loss: [3.0273972 2.1773567 1.1713005 0.7561097 0.63606775 0.8772826 ] Correct: [ 8.5 34.25 61.5 75.75 79.5 76.125]% Cross Entropy Loss: [3.2097042 2.0563214 1.1072843 0.9506593 0.80205727 0.931341 ] Correct: [ 9.625 39.75 62.5 72.875 79. 
77.5 ]% Cross Entropy Loss: [3.0743937 2.1993039 1.097712 0.6858389 0.5766239 0.80621386] Correct: [ 9.125 37.5 63.75 74.75 80.875 80.5 ]% Cross Entropy Loss: [3.0671723 2.0141437 1.1794173 0.7060026 0.5225506 0.5934661] Correct: [ 9.75 38.5 59.625 75.125 81.625 82.5 ]% Cross Entropy Loss: [2.9994457 2.1309588 1.1689099 0.65776 0.5776539 0.7486261] Correct: [ 9.125 39.5 63.625 78.625 82.625 80.75 ]% Cross Entropy Loss: [3.0490613 1.7696857 0.87877256 0.6178111 0.5513927 0.79168826] Correct: [ 9.125 41.5 62. 77.375 79.75 80. ]% Cross Entropy Loss: [3.0263753 1.8986218 1.0611737 0.69556695 0.6273117 0.7438471 ] Correct: [ 9.75 41.75 65.625 75. 80.875 80.25 ]% Cross Entropy Loss: [3.0008912 2.007826 0.98661995 0.6089993 0.5939046 0.6744814 ] Correct: [10.125 38. 67.5 77.75 82.375 82.125]% Cross Entropy Loss: [3.1105425 1.6115786 0.9665772 0.5954911 0.5312096 0.57376426] Correct: [10.5 44.25 63.75 77.75 81.625 80.5 ]% Cross Entropy Loss: [3.0068498 2.02105 0.8766573 0.6144441 0.60451126 0.6416334 ] Correct: [ 9.75 41. 64. 77.25 81. 81.375]% Cross Entropy Loss: [3.160195 1.950182 0.9932415 0.68624324 0.58092666 0.7089085 ] Correct: [ 8.625 38.25 64.5 78.25 84. 83.875]% Cross Entropy Loss: [3.147021 1.6818091 0.97858095 0.6021178 0.50018394 0.4438458 ] Correct: [ 8.625 43.875 67.25 78.875 82.625 84.625]% Cross Entropy Loss: [3.0289884 1.6315044 1.0720955 0.6997339 0.64312404 0.6109128 ] Correct: [ 9.375 44.875 65.875 77.375 81.25 83.75 ]% Cross Entropy Loss: [3.1076818 1.7735803 0.9963292 0.6965545 0.5888001 0.67020994] Correct: [ 9.75 39.25 68.625 78.5 81.125 83.375]% Cross Entropy Loss: [3.0140233 1.7964083 1.0170692 0.71051633 0.6399194 0.62391514] Correct: [ 7.25 44.125 64.625 78.5 80. 82.5 ]% Cross Entropy Loss: [3.2983482 1.4336803 0.8651347 0.59745806 0.59583783 0.47938347] Correct: [ 7.75 49. 69.625 79.75 81.625 85.5 ]% Cross Entropy Loss: [3.0227172 1.6308626 0.8873695 0.5688719 0.5236004 0.4128879] Correct: [10. 44.5 66.375 78.625 80.625 85. 
]% Cross Entropy Loss: [3.1221206 1.8216858 1.086255 0.5396703 0.4476328 0.45825785] Correct: [10.375 43.625 62.75 80.125 83.5 87. ]% Cross Entropy Loss: [3.0165932 1.594445 0.82931024 0.62122405 0.56433904 0.53335106] Correct: [10.75 45.875 69.25 76.625 83.125 84. ]% Cross Entropy Loss: [2.976345 1.7552588 0.8727827 0.5439472 0.487845 0.48407856] Correct: [ 8.625 40.75 68.875 80.875 82.875 85.375]% Cross Entropy Loss: [3.1561894 2.109057 1.0499103 0.649124 0.6493728 0.6012848] Correct: [10.75 39.375 63. 81.75 81.875 84.5 ]% Cross Entropy Loss: [3.0022197 1.6594759 0.88983524 0.6084543 0.4750777 0.63930285] Correct: [ 8.875 40.625 65.375 80.75 83. 84.625]% Cross Entropy Loss: [3.0250363 1.6357654 0.7875086 0.5366533 0.4635105 0.38862798] Correct: [ 8.25 43.875 69.875 79.75 84.5 85.625]% Cross Entropy Loss: [3.0546546 1.8087715 1.1468611 0.7362593 0.52866495 0.54879266] Correct: [10.5 44.375 58.875 74.125 82.5 85.75 ]% Cross Entropy Loss: [2.9471247 1.7242441 0.8533147 0.6239732 0.46459252 0.4667224 ] Correct: [ 8.75 43. 69.625 79.375 83.75 85.625]% Cross Entropy Loss: [2.9403381 2.424672 1.1020358 0.64152163 0.48979133 0.54925853] Correct: [ 9.75 39.75 65.625 78.5 85. 84.75 ]% Cross Entropy Loss: [2.9274268 1.7775714 1.0676163 0.7637645 0.59661067 0.696983 ] Correct: [12.5 47.125 67. 76.125 83.25 82.375]% Cross Entropy Loss: [3.0133202 2.0769572 1.0828611 0.6973138 0.6013626 0.6315218] Correct: [13.375 41.875 62.5 74.75 82.375 83.75 ]% Cross Entropy Loss: [3.2284603 1.7134516 0.9874684 0.72427064 0.5652125 0.6226713 ] Correct: [12. 41.125 67.25 74.625 81.125 83.125]% Cross Entropy Loss: [3.4371686 1.851958 0.8499395 0.51909226 0.45415252 0.37409484] Correct: [12.75 38.25 67.5 78.625 80.625 84. ]% Cross Entropy Loss: [3.0054333 1.6608994 0.8732424 0.6181227 0.50393623 0.5524405 ] Correct: [11.5 46.25 72.25 80.75 86. 84.75]% Cross Entropy Loss: [3.0089493 1.8744663 1.0633396 0.6948918 0.62322265 0.69205284] Correct: [10.375 42.5 68.625 77.75 82.25 83. 
]% Cross Entropy Loss: [3.0932999 1.7876904 0.9634226 0.5683807 0.48349482 0.42919135] Correct: [ 9.125 43. 67.5 80.625 85.25 86.5 ]% Cross Entropy Loss: [3.19741 2.1482918 0.9270364 0.50343615 0.49512953 0.43248376] Correct: [ 9.875 39.625 64.375 82.125 84. 85.875]% Cross Entropy Loss: [2.9276383 1.4919791 0.8011287 0.53904796 0.465405 0.44272572] Correct: [13.875 48.875 70.5 79.625 84.375 87.25 ]% Cross Entropy Loss: [3.0575027 1.6998974 1.2657552 0.8699988 0.6387011 0.5657667] Correct: [14. 49.375 64.125 77.25 81.875 84.875]% Cross Entropy Loss: [3.0627096 1.6546942 1.0380795 0.6233481 0.54562175 0.5222074 ] Correct: [14.5 44.75 69. 81.375 83.25 84.625]% Cross Entropy Loss: [2.8974407 1.957535 1.0858295 0.80741656 0.6602985 0.5484301 ] Correct: [11.125 42.25 65.25 75.875 81.5 83. ]% Cross Entropy Loss: [3.4451847 1.9205099 0.8657915 0.4949943 0.4703858 0.42402235] Correct: [11.875 48.875 71.5 82.625 84.25 87.25 ]% Cross Entropy Loss: [2.9432416 2.0023696 0.7916729 0.5706314 0.5208697 0.5321156] Correct: [12.5 38.625 68. 82.5 83.5 84.875]% Cross Entropy Loss: [3.030706 1.6219174 0.881166 0.56137925 0.49297038 0.48670998] Correct: [ 9.875 46.625 67.5 78.5 81.125 85.5 ]% Cross Entropy Loss: [3.1968966 1.8695322 1.1020805 0.77584946 0.5797532 0.51994276] Correct: [10.25 38.25 61.625 75.875 80.625 83.125]% Cross Entropy Loss: [3.040295 1.6315901 0.8334408 0.6557951 0.50152415 0.61000997] Correct: [11.375 45.25 71. 80.5 85.625 84.625]% Cross Entropy Loss: [2.9667509 1.7627857 0.8647931 0.60715836 0.60728353 0.5163029 ] Correct: [ 9.75 46.375 72.25 80.5 79.75 85.75 ]% Cross Entropy Loss: [2.9047413 1.8478705 0.80170983 0.5453348 0.5021038 0.47545013] Correct: [10. 
45.5 69.5 80.5 83.25 84.375]% Cross Entropy Loss: [2.8838997 1.8833306 0.97277486 0.58850855 0.4655931 0.4989623 ] Correct: [13.5 39.875 68.75 78.625 81.75 82.875]% Cross Entropy Loss: [2.9431396 1.7508295 0.7993283 0.6839116 0.5515451 0.5776764] Correct: [14.75 44.375 69.5 79.875 83.125 84.125]% Cross Entropy Loss: [3.1627584 1.8046802 1.0227302 0.54682213 0.47725177 0.4406547 ] Correct: [13. 44. 68.125 82.125 85.375 86.375]% Cross Entropy Loss: [3.0061405 1.7326323 0.8558343 0.55225575 0.4413685 0.42838746] Correct: [ 9.875 43. 64.875 79.75 85.625 84.625]% Cross Entropy Loss: [3.0251403 2.1437316 0.87921894 0.5408951 0.494903 0.49094287] Correct: [11. 43. 67.875 79.5 83.375 83.375]% Cross Entropy Loss: [2.9322987 1.652333 0.6880803 0.49353772 0.35898796 0.36407328] Correct: [15.125 45.75 76.25 82.625 85.375 87.375]% Cross Entropy Loss: [3.200056 1.6237847 0.82318056 0.5221631 0.42583814 0.43888077] Correct: [11.5 47.875 73.5 81.125 87. 85.5 ]% Cross Entropy Loss: [3.0273337 1.9252964 0.876787 0.67927283 0.5911812 0.54684985] Correct: [12.625 40.625 70. 78. 81.75 85.375]% Cross Entropy Loss: [3.2836616 1.6204703 0.9664984 0.62907064 0.49594408 0.533266 ] Correct: [10. 46.875 69. 79. 84.125 84. ]% Cross Entropy Loss: [3.0609908 1.8892581 1.041799 0.6662888 0.41626057 0.4137256 ] Correct: [14.75 44.625 67.25 77.75 85.25 88. ]% Cross Entropy Loss: [2.8581684 1.8039081 1.0574852 0.60612345 0.51491916 0.54161006] Correct: [12.125 42.125 65.25 79. 82.5 83.625]% Cross Entropy Loss: [3.1883678 1.5175365 0.7681623 0.5504075 0.5069863 0.4727601] Correct: [13.875 45. 71.5 80.125 84.125 85.375]% Cross Entropy Loss: [4.5461597 1.9413345 1.2279657 0.7244 0.50492215 0.47895068] Correct: [10.75 41.625 64. 
77.625 84.125 84.375]% Cross Entropy Loss: [2.8848042 2.0539155 0.86577195 0.50719327 0.3743063 0.42951965] Correct: [12.375 42.375 72.125 82.625 85.75 86.625]% Cross Entropy Loss: [3.1026292 2.6553998 1.1793092 0.76385534 0.7186745 0.67797405] Correct: [10.75 39.625 66.75 81.625 82. 84.75 ]% Cross Entropy Loss: [2.9684005 1.5610086 0.82370996 0.53371507 0.48485833 0.48094493] Correct: [13.625 46.75 69.25 81.125 83.5 85.75 ]% Cross Entropy Loss: [3.4059956 2.1697114 1.169297 0.5937111 0.5299756 0.5177182] Correct: [10.5 40.875 66.25 79.75 83.875 84.625]% Cross Entropy Loss: [2.981723 1.5343466 0.7402752 0.5267539 0.39910927 0.41904512] Correct: [12.125 46.125 74.375 82.625 87. 88. ]% Cross Entropy Loss: [2.8993185 1.6620448 0.8629214 0.6263448 0.6495975 0.5531028] Correct: [13.125 43.125 69.125 80.125 82.5 85.875]% Cross Entropy Loss: [2.8841465 1.6387984 0.6655936 0.46667004 0.40492797 0.3920607 ] Correct: [10.625 41.375 72.5 82.125 83.875 85.25 ]% Cross Entropy Loss: [3.0970635 1.6737862 0.87659276 0.5304945 0.52965105 0.4708584 ] Correct: [12.75 47. 
71.375 82.75 83.875 85.625]% Cross Entropy Loss: [3.2678325 1.5704967 0.83065474 0.5681037 0.5583045 0.5014496 ] Correct: [14.5 46.125 73.875 81.125 80.875 84.375]% Cross Entropy Loss: [3.4222984 1.5396487 0.9070364 0.6226528 0.48775855 0.51048726] Correct: [10.875 45.5 67.5 79.75 84.375 84.375]% Cross Entropy Loss: [3.081642 1.6114861 1.1220353 0.6021444 0.47716507 0.49586043] Correct: [11.75 48.125 63.375 80.375 84.5 85.75 ]% Cross Entropy Loss: [3.061386 1.5051836 0.6441544 0.6589028 0.51302254 0.47163734] Correct: [ 8.875 43.75 72.375 80.25 85.5 87.25 ]% Cross Entropy Loss: [2.9992366 1.9130338 1.0048182 0.68723977 0.5332165 0.54111946] Correct: [12.5 40.375 72.375 81.5 83.5 86.375]% Cross Entropy Loss: [3.0111508 2.0618596 0.878984 0.5489829 0.48416367 0.50063264] Correct: [15.375 37.5 67.75 80.5 83.25 84.875]% Cross Entropy Loss: [3.0650299 1.8617165 0.79462755 0.5597135 0.55789185 0.4707482 ] Correct: [13.125 48.375 71.375 82.25 82.75 85.625]% Cross Entropy Loss: [2.86556 1.4245446 0.81697196 0.60661584 0.46801043 0.47327223] Correct: [13.125 50.5 68.625 80. 84.125 86.625]% Cross Entropy Loss: [3.5813553 1.7462721 0.6996439 0.45092177 0.3761281 0.4439034 ] Correct: [12.25 48.125 71. 83.5 85. 85.375]% Cross Entropy Loss: [2.8774676 1.9664905 1.0307274 0.50789547 0.43523592 0.43112522] Correct: [13.25 40.25 70.5 82.125 84.5 85.625]% Cross Entropy Loss: [2.9437509 1.6426048 0.9455849 0.67471564 0.6092092 0.5702437 ] Correct: [13.125 44.125 68.625 79.875 80.25 83.5 ]% Cross Entropy Loss: [2.9589424 1.5227692 0.78685045 0.5386065 0.46356648 0.49786806] Correct: [14.375 45.125 71.125 82.125 85.75 86.5 ]% Cross Entropy Loss: [3.0250165 1.7240454 0.7977293 0.54554 0.43309125 0.3884979 ] Correct: [10.625 47.375 73. 82. 
84.75 86.75 ]% Cross Entropy Loss: [3.129272 1.967033 0.9577776 0.58596694 0.49101862 0.5307419 ] Correct: [15.25 44.125 68.125 79.875 84.25 84.25 ]% Cross Entropy Loss: [2.9101024 1.7318776 1.1131595 0.5208601 0.4648195 0.4657405] Correct: [17.25 46.125 67.5 84.375 84.375 85.375]% Cross Entropy Loss: [3.5960832 1.9333878 0.7571655 0.60292566 0.50618345 0.4628747 ] Correct: [10. 40.25 70.375 79.125 82.25 85.625]% Cross Entropy Loss: [3.029155 1.4889362 0.73992825 0.5707727 0.45236874 0.43567744] Correct: [12.25 48.875 73. 82.125 83.125 87. ]% Cross Entropy Loss: [2.8766136 1.6459885 0.84303045 0.56460893 0.3991136 0.41761708] Correct: [13.875 42.875 69.875 79.625 84. 86. ]% Cross Entropy Loss: [3.4214454 1.5972183 0.8658182 0.64118004 0.56899863 0.5759396 ] Correct: [13.25 50.875 71.625 79.875 83.625 82.625]% Cross Entropy Loss: [2.9154363 1.7812408 0.7149235 0.5127851 0.459591 0.45135212] Correct: [10.5 41.375 70.625 81. 86.125 86.625]% Cross Entropy Loss: [2.9860187 1.6418374 0.7283004 0.47905737 0.4052481 0.3719082 ] Correct: [10.125 48.5 74.5 82.25 85.375 87.625]% Cross Entropy Loss: [2.9465632 2.1019638 0.7686553 0.47209206 0.4784643 0.434412 ] Correct: [10.75 40.5 70.75 82.875 85.125 86.625]% Cross Entropy Loss: [2.9009266 1.6417015 0.8166645 0.58031374 0.4775436 0.4580237 ] Correct: [14.625 46.25 67.25 80.375 83.5 84.375]% Cross Entropy Loss: [2.90346 2.0258083 0.9324237 0.5852523 0.47885126 0.424465 ] Correct: [10.125 43.75 71.5 79.625 84.5 84.75 ]% Cross Entropy Loss: [2.9447227 2.3738275 1.2133219 0.5977822 0.43241492 0.36062104] Correct: [10.25 43.5 67.5 77.5 84.5 87.75]%
# Report mean per-readout accuracy over the final quarter of training.
last_quarter = acc[-len(acc)//4:]
print(f'Accuracy of last quarter: {100*last_quarter.mean(axis=0)}%')
plt.figure()
plt.plot(100 * np.asarray(acc))
plt.ylabel('Accuracy [%]')
plt.xlabel('Training Step [x500]')
# Legend entry 0 is the input-only readout, then one entry per layer.
labels = ['From Inputs directly'] + [f'From Layer {i+1}' for i in range(len(SNN.clapp))]
plt.legend(labels)
# plt.ylim([75, 90])
plt.figure()
print(losses_out.shape)
# Smoothed cross-entropy curve per readout over a shared x-axis.
epoch_axis = np.arange(len(losses_out)) / len(train_loader)
for col in range(losses_out.shape[1]):
    smoothed = savgol_filter(losses_out[:, col], 99, 1)
    plt.plot(epoch_axis, smoothed, label=labels[col])
plt.ylabel('Cross Entropy Loss')
plt.xlabel('Training Step')
plt.legend();
Accuracy of last quarter: [12.49038462 44.97115385 70.21153846 81.02884615 83.99038462 85.66346154]% torch.Size([1020, 6])
from tqdm.notebook import trange
# Evaluate every trained readout on the full test set.
correct = torch.zeros(len(out_projs))
for out_proj in out_projs:
    out_proj.eval()
SNN.eval()
# Confusion matrix (rows = true class, cols = predicted class).
pred_matrix = torch.zeros(n_outputs, n_outputs)
for idx in trange(0, len(test_loader), batch_size):
    for out_proj in out_projs:
        out_proj.reset()
    SNN.reset(0)
    inp, target = test_loader.x[idx:idx+batch_size], test_loader.y[idx:idx+batch_size]
    # All list slots alias ONE zero tensor, but the update below rebinds
    # (logits[i] = logits[i] + out) instead of mutating, so this is safe.
    logits = len(out_projs)*[torch.zeros((inp.shape[0],20))]
    # Unroll over time; prepend the raw input so index 0 feeds the
    # direct-from-input readout, matching train_out_proj's ordering.
    for step in range(inp.shape[1]):
        data_step = inp[:,step].float().to(device)
        spk_step, _, _ = SNN(data_step, None, 0)
        spk_step = [data_step, *spk_step]
        for i, out_proj in enumerate(out_projs):
            out, _ = out_proj(spk_step[i], target)
            logits[i] = logits[i] + out
    # Count correct time-summed predictions per readout.
    for i, logit in enumerate(logits):
        pred = logit.argmax(axis=-1)
        correct[i] += int((pred == target).sum())
    # for the last layer create the prediction matrix
    # (uses `pred` left over from the final loop iteration above)
    for j in range(pred.shape[0]):
        pred_matrix[int(target[j]), int(pred[j])] += 1
# Convert counts to accuracy fractions.
correct /= len(test_loader)
print('Directly from inputs:')
print(f'Accuracy: {100*correct[0]:.2f}%')
for i in range(len(out_projs)-1):
    print(f'From layer {i+1}:')
    print(f'Accuracy: {100*correct[i+1]:.2f}%')
plt.imshow(pred_matrix, origin='lower')
plt.title('Prediction Matrix for the final layer')
plt.xlabel('Prediction')
plt.ylabel('Target')
plt.xticks([i for i in range(n_outputs)])
plt.yticks([i for i in range(n_outputs)])
plt.colorbar();
0%| | 0/36 [00:00<?, ?it/s]
Directly from inputs: Accuracy: 12.10% From layer 1: Accuracy: 38.21% From layer 2: Accuracy: 51.94% From layer 3: Accuracy: 61.88% From layer 4: Accuracy: 67.14% From layer 5: Accuracy: 67.80%
# Sanity check: final-layer accuracy recomputed from the confusion matrix,
# plus the raw counts and the test-set size for comparison.
print(torch.diag(pred_matrix).sum()/pred_matrix.sum())
print(pred_matrix.diag().sum(), pred_matrix.sum(), len(test_loader))
from snntorch import spikeplot as spkplt
# Raster plots of last-layer spikes and readout output spikes for five
# samples. NOTE(review): `inp` and `target` are leftovers from the last
# batch of the evaluation loop above — this cell depends on running it first.
SNN.eval()
for out_proj in out_projs:
    out_proj.eval()
data_lastlayer = torch.zeros(n_time_bins, 512)
data_out = torch.zeros(n_time_bins, 20)
for idx in range(5):
    for step in range(inp.shape[1]):
        data_step = inp[:,step].float().to(device)
        spk_step, _, _ = SNN(data_step, None, 0)
        # Record sample `idx`'s spikes in the last hidden layer.
        data_lastlayer[step] = spk_step[-1][idx]
        out, _ = out_projs[-1](spk_step[-1][idx], target)
        # NOTE(review): indexes out[0], not out[idx] — verify which sample's
        # readout this records when `out` is batched.
        data_out[step] = out[0]
    print(target[idx])
    fig = plt.figure(facecolor="w", figsize=(10, 5))
    ax = fig.add_subplot(111)
    print(data_lastlayer.mean(), data_out.mean())
    spkplt.raster(data_lastlayer, ax, s=1.5, color='black')
    fig = plt.figure(facecolor="w", figsize=(10, 5))
    ax = fig.add_subplot(111)
    spkplt.raster(data_out, ax, s=5, color='black')
tensor(0.6780) tensor(1535.) tensor(2264.) 2264 tensor(18.) tensor(0.0378) tensor(0.0255) tensor(8.) tensor(0.0652) tensor(0.0240) tensor(4.) tensor(0.0674) tensor(0.0235) tensor(1.) tensor(0.0632) tensor(0.0210) tensor(8.) tensor(0.0472) tensor(0.0215)